import datetime
from pathlib import Path

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from adjustText import adjust_text
from dateutil.relativedelta import relativedelta


# Figure-wide matplotlib defaults: 22 pt body and axis-label text, 20 pt
# legend and tick text, and thin dotted black grid lines.
params = {
    **dict.fromkeys(("axes.labelsize", "font.size"), 22),
    **dict.fromkeys(("legend.fontsize", "xtick.labelsize", "ytick.labelsize"), 20),
    "grid.color": "k",
    "grid.linestyle": ":",
    "grid.linewidth": 1,
}
mpl.rcParams.update(params)

# Download the ISO 3166 country table and build alpha-3 <-> alpha-2 lookups.
# na_filter=False keeps every cell as literal text; presumably this is what
# stops a code such as the alpha-2 value "NA" from being read as missing.
country_code = pd.read_csv(
    "https://raw.githubusercontent.com/lukes/ISO-3166-Countries-with-Regional-Codes/refs/heads/master/all/all.csv",
    na_filter=False,
)
dict_iso3_to_iso2 = country_code.set_index("alpha-3")["alpha-2"].to_dict()
dict_iso2_to_iso3 = country_code.set_index("alpha-2")["alpha-3"].to_dict()

# Migration-flow estimates; the columns used in this script are
# migration_month, country_to and num_migrants.
estimate = pd.read_csv("international_migration_flow.csv", na_filter=False)

# Total number of migrants per month over all rows, sorted by month label.
# drop=True discards the old positional index instead of inserting it into
# the result as an extra "index" column.
estimate_three_years_sum = (
    estimate.groupby("migration_month", as_index=False)["num_migrants"]
    .sum()
    .sort_values(by="migration_month")
    .reset_index(drop=True)
)

# 2019 baseline: total arrivals per destination country ...
df_2019_dest = estimate.loc[estimate["migration_month"].str.contains("2019")]
df_2019_dest = df_2019_dest.groupby("country_to", as_index=False)[
    "num_migrants"
].sum()

# ... and the same total divided by twelve as a per-month reference level.
df_2019_dest_monthly_avg = df_2019_dest.assign(
    num_migrants_monthly=df_2019_dest["num_migrants"] / 12
)

# Arrivals per destination country for every month labelled 2020 or 2021.
df_2020_2021_monthly_dest = estimate.loc[
    estimate["migration_month"].str.contains("2020|2021", regex=True)
]
df_2020_2021_monthly_dest = df_2020_2021_monthly_dest.groupby(
    ["country_to", "migration_month"], as_index=False
)["num_migrants"].sum()

# Attach the 2019 monthly reference level to every 2020-21 country-month;
# destinations without a 2019 value are dropped by the inner join.
migrants_change = pd.merge(
    df_2020_2021_monthly_dest,
    df_2019_dest_monthly_avg[["country_to", "num_migrants_monthly"]],
    how="inner",
    on="country_to",
)

# Relative change with respect to the 2019 monthly level.
migrants_change["change"] = (
    migrants_change["num_migrants"]
    .sub(migrants_change["num_migrants_monthly"])
    .div(migrants_change["num_migrants_monthly"])
)

# Same quantity limited to the interval [-1, 1] for plotting.
migrants_change["change_scale"] = migrants_change["change"].clip(lower=-1, upper=1)

# Wide layout: one row per destination country, one column per month.
migrants_change_pivot = migrants_change.pivot(
    index="country_to", columns="migration_month", values="change_scale"
)

# Plotting window: 1 January 2020 through 1 May 2021, both ends included.
start = datetime.datetime.strptime("01-01-2020", "%d-%m-%Y")
end = datetime.datetime.strptime("01-05-2021", "%d-%m-%Y")

# Every calendar day in the window. A day's position in date_range is the
# x coordinate the per-country plotting helpers use for it.
date_generated = pd.date_range(start, end, freq="D").to_pydatetime().tolist()
date_range = [day.strftime("%Y-%m-%d") for day in date_generated]

# First day of every month in the window. The number of months follows from
# start/end rather than from a separately maintained constant.
month_generated = pd.date_range(start, end, freq="MS").to_pydatetime().tolist()
month_range = [month.strftime("%Y-%m-%d") for month in month_generated]

# Position of each month start inside date_range. Because date_range is
# contiguous and begins at `start`, that position is the day offset.
m_index = [(month - start).days for month in month_generated]

# Daily grid as a frame, used as the left side of merges onto the day axis.
month_df = pd.DataFrame({"date": date_range})
# Display name for each ISO alpha-2 code; used as the legend label of the
# per-country migration lines.
# The palette for the highlighted countries (colors / sample_country /
# color_dict) is assigned once, further down, just before the plotting
# helpers. An earlier copy here was overwritten before it was ever read.
country_dict = dict(zip(country_code["alpha-2"], country_code["name"]))

# stringency
# Download the data from https://github.com/owid/covid-19-data/blob/master/public/data/owid-covid-data.csv
# Only the three columns used below are read from the file.
df_stringency = pd.read_csv(
    "owid-covid-data.csv",
    usecols=["date", "iso_code", "stringency_index"],
)

# Global series: median stringency index over all rows sharing a date.
# Dates are compared as text, which assumes "YYYY-MM-DD" strings.
daily_stringency_index = (
    df_stringency.groupby("date", as_index=False)["stringency_index"]
    .median()
    .sort_values(by="date")
)
stringency_global = daily_stringency_index[daily_stringency_index.date <= "2021-12-31"]

# Per-country series: median per (date, iso_code) pair, same date cut-off.
daily_stringency_index = (
    df_stringency.groupby(["date", "iso_code"], as_index=False)["stringency_index"]
    .median()
    .sort_values(by="date")
)
stringency_per_country = daily_stringency_index[
    daily_stringency_index.date <= "2021-12-31"
]

# Highlighted example countries (ISO alpha-2 codes) and the line colour used
# for each. The list order matters: countries are paired by position when
# the per-type panels are drawn.
color_dict = {
    "NE": "#1b9e77",
    "HT": "#d95f02",
    "FR": "#4f4a8c",
    "SK": "#e7298a",
    "AU": "#66a61e",
    "AR": "#e6ab02",
}
sample_country = list(color_dict)
colors = list(color_dict.values())


def plot_one_country_change(ax, iso2, _df):
    """Draw one country's monthly migration change as a marked line.

    Args:
        ax: Matplotlib axes to draw on; its y-limits are set to (-1, 1).
        iso2: ISO alpha-2 code. Must be a row label of ``_df`` and a key of
            the module-level ``color_dict`` and ``country_dict``.
        _df: Table with one row per country. Its columns are paired
            positionally with ``month_range``, so it is assumed to hold
            exactly one column per entry of ``month_range``, in that order.
    """
    monthly = pd.DataFrame(
        {"date": month_range, "v": _df.loc[iso2, :].values.tolist()}
    )
    # Place the monthly values on the daily grid so that each point's x
    # coordinate is its day position; days without a value become NaN.
    on_daily_grid = month_df.merge(monthly, on="date", how="left")
    values = on_daily_grid["v"]
    has_value = np.isfinite(values)
    ax.plot(
        np.flatnonzero(has_value.to_numpy()),
        values[has_value].tolist(),
        "-o",
        c=color_dict[iso2],
        label=country_dict[iso2],
        zorder=2,
    )
    ax.set_ylim(-1, 1)


def plot_one_country_stringency(ax, iso2):
    """Draw one country's daily stringency index as a faint dashed line.

    Args:
        ax: Matplotlib axes to draw on; its y-limits are set to (0, 100).
        iso2: ISO alpha-2 code. It is translated to alpha-3 through
            ``dict_iso2_to_iso3`` to select rows of the module-level
            ``stringency_per_country`` table, and is also the key used to
            look up the line colour in ``color_dict``.
    """
    iso3 = dict_iso2_to_iso3[iso2]
    country_rows = stringency_per_country[stringency_per_country.iso_code == iso3]
    # Align to the daily grid so x coordinates match the migration lines;
    # days with no finite index value are skipped.
    on_daily_grid = month_df.merge(country_rows, on="date", how="left")
    values = on_daily_grid["stringency_index"]
    has_value = np.isfinite(values)
    ax.plot(
        np.flatnonzero(has_value.to_numpy()),
        values[has_value].tolist(),
        "--",
        c=color_dict[iso2],
        alpha=0.5,
        zorder=3,
    )
    ax.set_ylim(0, 100)


def plot_two_countries(ax, c1, c2, _df):
    """Draw migration change and policy stringency for a pair of countries.

    Migration change goes on ``ax`` (left scale); the stringency index goes
    on a twin axes created here (right scale). The twin is not returned, but
    creating it makes it pyplot's current axes, which the callers rely on
    for their subsequent ``plt.title`` / ``plt.text`` calls.

    Several details are keyed on ``c1``: a line-style key box is drawn when
    it is "NE", month tick labels are shown only when it is "AU", and the
    legend location is pinned for "NE" and "FR".

    Args:
        ax: Matplotlib axes for the migration-change lines.
        c1: ISO alpha-2 code of the first country.
        c2: ISO alpha-2 code of the second country.
        _df: Table passed through to ``plot_one_country_change``.
    """
    if c1 == "NE":
        # Hand-placed key, in data coordinates, explaining the line styles.
        shift = -0.05
        left = 5
        right = left + 202
        top = -0.3 + shift
        bottom = -0.9 + shift
        ax.plot(
            [left, right, right, left, left],
            [bottom, bottom, top, top, bottom],
            c="lightgray",
            linewidth=1,
            zorder=1,
        )
        ax.plot([10, 33], [-0.45 + shift, -0.45 + shift], "-", c="k")
        ax.plot([10, 33], [-0.7 + shift, -0.7 + shift], "--", c="k")
        ax.text(42, -0.5 + shift, "Migration change", fontsize=12)
        ax.text(42, -0.77 + shift, "COVID-19 policy stringency index", fontsize=12)

    for iso2 in (c1, c2):
        plot_one_country_change(ax, iso2, _df)
    ax.yaxis.set_major_formatter(mpl.ticker.StrMethodFormatter("{x:,.0f}"))

    ax_right = ax.twinx()
    for iso2 in (c1, c2):
        plot_one_country_stringency(ax_right, iso2)

    # pyplot calls below act on the current axes, i.e. the twin, which
    # shares its x axis with ``ax``.
    plt.xlim(0, len(stringency_global) - 1)
    ax.axhline(y=0, linestyle="--", color="gray")

    if c1 == "AU":
        ax.set_xticks(m_index)
        ax.set_xticklabels(
            [m[:7] for m in month_range],
            rotation=45,
            ha="right",
            rotation_mode="anchor",
        )
    else:
        plt.xticks(m_index, [])

    legend_options = {"fontsize": 12, "fancybox": False}
    pinned_location = {"NE": 1, "FR": 4}.get(c1)
    if pinned_location is not None:
        legend_options["loc"] = pinned_location
    leg = ax.legend(**legend_options)
    leg.get_frame().set_edgecolor("lightgray")


# download population data from https://data.worldbank.org/indicator/SP.POP.TOTL
population = pd.read_csv("world_population.csv", na_filter=False)
# Keep the two columns needed and drop rows with an empty 2019 value. The
# explicit copy makes the column assignment below operate on an independent
# frame rather than on a slice of the frame that was read.
population = population.loc[
    population["popu_2019"] != "", ["iso3", "popu_2019"]
].copy()
population["popu_2019"] = population["popu_2019"].astype(int)
# add Taiwan
taiwan_population = pd.DataFrame({"iso3": ["TWN"], "popu_2019": [23e6]})
# ignore_index avoids ending up with two rows that share index label 0.
population = pd.concat([population, taiwan_population], ignore_index=True)
# Keep only codes present in the ISO table, then attach the alpha-2 code.
population = population[population.iso3.isin(dict_iso3_to_iso2.keys())].copy()
population["iso2"] = population["iso3"].map(dict_iso3_to_iso2)

# Drop June through December 2021 so the table ends at the 2021-05 column.
migrants_change_pivot = migrants_change_pivot.drop(
    columns=[f"2021-{month_number:02d}" for month_number in range(6, 13)]
)
# only show those countries with at least 2 million people to avoid overlapping
migrants_change_pivot = migrants_change_pivot.loc[
    migrants_change_pivot.index.isin(
        population.loc[population["popu_2019"] > 2e6, "iso2"]
    )
]

# Cut the global stringency series at 1 May 2021 and keep its date labels.
stringency_global = stringency_global.loc[stringency_global["date"] <= "2021-05-01"]
dates = stringency_global["date"].tolist()

# Monthly migrant totals restricted to 2020-01 .. 2021-05 (both included).
_estimate_three_years_sum = estimate_three_years_sum[
    estimate_three_years_sum["migration_month"].between("2020-01", "2021-05")
]

# Position of each month's first day within the stringency date list; these
# are the x coordinates of the monthly totals on the global-trend panel.
month_index = [
    dates.index(month_label + "-01")
    for month_label in _estimate_three_years_sum["migration_month"]
]

# Figure canvas. Spacing is set explicitly here; the panels below use
# hand-tuned annotation coordinates, so no automatic layout is applied.
fig = plt.figure()
fig.set_size_inches(24, 12)
fig.subplots_adjust(wspace=0.12, hspace=0.2)

# Panel A (left half): one point per country, migration change in October
# 2020 (x) against April 2021 (y).
ax1 = fig.add_subplot(1, 2, 1)
ax1.scatter(
    migrants_change_pivot["2020-10"], migrants_change_pivot["2021-04"], s=10, c="gray"
)

# Label every point with its alpha-3 code. The highlighted countries get a
# fixed offset and their own colour; all other labels are collected so that
# adjust_text can move them apart.
# NOTE(review): a country missing either month would be labelled at NaN
# coordinates; confirm adjust_text tolerates that if such rows can occur.
texts = []
for c, row in migrants_change_pivot.iterrows():
    x_value = row["2020-10"]
    y_value = row["2021-04"]
    code = dict_iso2_to_iso3[c]
    if c == "AR":
        ax1.text(
            x_value - 0.05, y_value + 0.01, code, fontsize=12, color=color_dict[c]
        )
    elif c in sample_country:
        ax1.text(
            x_value + 0.01, y_value - 0.01, code, fontsize=12, color=color_dict[c]
        )
    else:
        texts.append(ax1.text(x_value, y_value, code, fontsize=12, color="k"))

# ax1 is still pyplot's current axes at this point.
adjust_text(texts)

# Zero lines split the plane into quadrants; three of them are tagged.
ax1.axvline(x=0, linestyle="--", color="gray")
ax1.axhline(y=0, linestyle="--", color="gray")
for tag_x, tag_y, tag in (
    (0.5, 0.6, "Type 1"),
    (0.5, -0.82, "Type 2"),
    (-0.5, -0.82, "Type 3"),
):
    ax1.text(
        tag_x,
        tag_y,
        tag,
        fontweight="bold",
        fontsize=15,
        bbox=dict(facecolor="none", edgecolor="k", pad=5.0),
    )

ax1.set_xlabel("Migration change in Oct 2020 relative to 2019")
ax1.set_ylabel("Migration change in Apr 2021 relative to 2019")
ax1.text(-1.15, 0.67, "(A)", fontsize=30)

# Panel B (top right): global monthly migrant totals in millions on the left
# scale, global median stringency index on the right scale.
ax2 = fig.add_subplot(4, 2, 2)
ax22 = ax2.twinx()

# stringency_global has already been cut at 2021-05-01 (month_index was
# derived from that same cut), so it is used as is.
ax2.plot(month_index, list(_estimate_three_years_sum["num_migrants"] / 1e6), "-", c="k")
ax22.plot(
    list(range(len(stringency_global))),
    list(stringency_global["stringency_index"]),
    "--",
    c="k",
    label="COVID-19 policy stringency index",
)

ax2.set_ylabel("# Migrants (M)", color="k")
ax2.set_ylim(0, 4)
ax22.set_ylabel("COVID-19 policy stringency index", color="k")
# In axes coordinates; a negative y places the label below this panel.
ax22.yaxis.set_label_coords(1.08, -1.4)

# Hand-placed key box, in ax2 data coordinates, explaining the line styles.
key_left = 280
key_right = key_left + 202
key_top = 1.8
key_bottom = 0.2
ax2.plot(
    [key_left, key_right, key_right, key_left, key_left],
    [key_bottom, key_bottom, key_top, key_top, key_bottom],
    c="lightgray",
    linewidth=1,
)
ax2.plot([key_left + 7, key_left + 30], [1.3, 1.3], "-", c="k")
ax2.plot([key_left + 7, key_left + 30], [0.6, 0.6], "--", c="k")
ax2.text(key_left + 35, 1.2, "# Migrants (Millions)", fontsize=12)
ax2.text(key_left + 35, 0.5, "COVID-19 policy stringency index", fontsize=12)

# The remaining calls target the twin axes (the current pyplot axes), so
# the panel tag is positioned on the 0-100 stringency scale.
ax22.set_ylim(0, 100)
ax22.set_xlim(0, len(stringency_global) - 1)
ax22.set_title("Global trends", fontsize=15)
ax22.set_xticks(m_index)
ax22.set_xticklabels([])
ax22.text(5, 78, "(B)", fontsize=30)

# Panels C-E (right column, rows 2-4): one pair of example countries each.
# Fields: subplot position, first country, second country, title, panel tag,
# optional y-axis label for the migration-change scale.
country_panels = [
    (4, sample_country[0], sample_country[1], "Type 1", "(C)", None),
    (
        6,
        sample_country[2],
        sample_country[3],
        "Type 2",
        "(D)",
        "Migration change relative to 2019",
    ),
    (8, sample_country[4], sample_country[5], "Type 3", "(E)", None),
]
panel_axes = []
for position, first, second, title, tag, ylabel in country_panels:
    panel_ax = fig.add_subplot(4, 2, position)
    plot_two_countries(panel_ax, first, second, migrants_change_pivot)
    # plot_two_countries leaves its twin (stringency) axes as the current
    # pyplot axes, so the title and the tag are placed on that axes and the
    # tag's y coordinate is on the 0-100 scale.
    plt.title(title, fontsize=15)
    if ylabel is not None:
        panel_ax.set_ylabel(ylabel)
    plt.text(5, 78, tag, fontsize=30)
    panel_axes.append(panel_ax)
ax3, ax4, ax5 = panel_axes

# Create the output directory if needed so that saving cannot fail on a
# missing folder after all the panels have been drawn.
output_path = Path("figures/fig3.pdf")
output_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(
    output_path,
    facecolor="white",
    transparent=False,
    bbox_inches="tight",
    dpi=300,
)
